# losses.py  — Hungarian-based label alignment (fast for K>=8)

import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from scipy.optimize import linear_sum_assignment

# ======= Loss (cross-entropy) =======
# Module-wide cross-entropy criterion, constructed with default arguments.
criterion = nn.CrossEntropyLoss()

# Legacy tensor-type aliases: the CUDA variants when a GPU is available,
# the CPU variants otherwise.
dtype, dtype_l = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor)
    if torch.cuda.is_available()
    else (torch.FloatTensor, torch.LongTensor)
)


# ------------------------
# Utilities
# ------------------------
def _to_1d_numpy(x):
    """Return `x` as a flattened 1-D NumPy array (always a copy).

    Torch tensors are detached and moved to the CPU first; anything else is
    handed to `np.asarray` unchanged.
    """
    source = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
    return np.asarray(source).flatten()


def from_scores_to_labels_multiclass_batch(pred):
    """
    Turn per-class scores into hard labels by taking the argmax over axis 2.

    Args:
        pred: (B, N, K) logits/scores; objects exposing `.detach()` (such as
              torch tensors) are detached and copied to the CPU first.
    Returns:
        (B, N) integer array holding the index of the highest score.
    """
    raw = pred.detach().cpu().numpy() if hasattr(pred, "detach") else pred
    scores = np.asarray(raw)
    winners = np.argmax(scores, axis=2)
    return winners.astype(int)


def compute_ari_nmi(pred, true):
    """Return the (ARI, NMI) pair between predicted and true labels.

    Both inputs are flattened to 1-D first. Neither score depends on how the
    cluster ids are named, so no label alignment is performed.
    """
    predicted_flat = _to_1d_numpy(pred)
    truth_flat = _to_1d_numpy(true)
    return (
        adjusted_rand_score(truth_flat, predicted_flat),
        normalized_mutual_info_score(truth_flat, predicted_flat),
    )


def compute_accuracy_multiclass_batch(labels_pred_1d, labels_true_1d):
    """Return the fraction of positions where prediction and truth agree.

    No label alignment is applied; the comparison is purely element-wise.
    """
    predicted = np.asarray(labels_pred_1d)
    expected = np.asarray(labels_true_1d)
    agreement = predicted == expected
    return float(np.mean(agreement))


# ------------------------
# Hungarian matching
# ------------------------
def _confusion_matrix(true, pred, n_classes):
    """
    Count (true, pred) label pairs into an (n_classes, n_classes) matrix.

    Entry [t, p] is the number of positions whose true label is t and whose
    predicted label is p. Labels are truncated to integers before counting.

    Args:
        true: (N,) true labels
        pred: (N,) predicted labels
        n_classes: number of classes (matrix side length)

    Returns:
        (n_classes, n_classes) integer np.ndarray of pair counts.

    Raises:
        ValueError: if `true` and `pred` do not have the same length.
        IndexError: if any label lies outside [0, n_classes).
    """
    true_idx = np.asarray(true).reshape(-1).astype(int)
    pred_idx = np.asarray(pred).reshape(-1).astype(int)
    if true_idx.size != pred_idx.size:
        raise ValueError(
            f"true and pred must have the same length, "
            f"got {true_idx.size} and {pred_idx.size}"
        )

    for name, idx in (("true", true_idx), ("pred", pred_idx)):
        # Negative labels would otherwise wrap around and be silently counted
        # as the last classes; too-large labels already raised IndexError.
        if idx.size and (idx.min() < 0 or idx.max() >= n_classes):
            raise IndexError(
                f"{name} labels must lie in [0, {n_classes}); "
                f"got values in [{idx.min()}, {idx.max()}]"
            )

    # Encode each (true, pred) pair as a flat cell index so one bincount
    # replaces the per-element Python loop.
    cells = true_idx * n_classes + pred_idx
    counts = np.bincount(cells, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes).astype(int)


def hungarian_match(pred, true, n_classes=None):
    """
    Find best label alignment using Hungarian algorithm.

    Args:
        pred: (N,) predicted labels
        true: (N,) true labels
        n_classes: optional number of classes; when omitted it is inferred as
                   max(label) + 1 over both inputs

    Returns:
        best_acc: accuracy after optimal alignment
        new_pred: remapped prediction aligned to true
        pred2true: dict mapping predicted label -> true label
        true2pred: dict inverse mapping true label -> predicted label

    Raises:
        ValueError: if `pred` and `true` differ in length, or if `n_classes`
                    is None and the inputs are empty (nothing to infer from).
        IndexError: if any label lies outside [0, n_classes).
    """
    pred = np.asarray(pred).flatten()
    true = np.asarray(true).flatten()
    if pred.size != true.size:
        raise ValueError(
            f"pred and true must have the same length, "
            f"got {pred.size} and {true.size}"
        )

    if n_classes is None:
        if pred.size == 0:
            raise ValueError("cannot infer n_classes from empty label arrays")
        n_classes = int(max(np.max(pred), np.max(true)) + 1)

    cm = _confusion_matrix(true, pred, n_classes)
    # Convert to cost (maximize trace(cm) <=> minimize cost)
    cost = cm.max() - cm
    row_ind, col_ind = linear_sum_assignment(cost)

    # Rows of cm index true labels and columns index predicted labels, so each
    # assigned (row, col) pair reads "predicted label col means true label row".
    pred2true = {int(c): int(r) for r, c in zip(row_ind, col_ind)}
    true2pred = {t: p for p, t in pred2true.items()}

    # Remap predictions to true-space with one vectorized lookup. The
    # assignment covers every class, and labels were range-checked above.
    lookup = np.empty(n_classes, dtype=int)
    lookup[col_ind] = row_ind
    new_pred = lookup[pred.astype(int)]

    best_acc = float(np.mean(new_pred == true))
    return best_acc, new_pred, pred2true, true2pred


# ------------------------
# Public APIs (Hungarian-based)
# ------------------------
def compute_accuracy_spectral(pred, true, n_classes=None):
    """
    Best accuracy over label relabelings, found with the Hungarian matcher.

    This replaces an earlier approach that enumerated all permutations
    (O(k!)); the Hungarian version is O(k^3) and scales to k=8+ easily.

    Args:
        pred: predicted labels (torch.Tensor or array-like), flattened to 1-D
        true: true labels (torch.Tensor or array-like), flattened to 1-D
        n_classes: optional number of classes, forwarded to hungarian_match
    Returns:
        (best_acc, best_pred_remapped)
    """

    def _flat_numpy(values):
        # Tensors are detached and copied to the CPU before conversion.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).flatten()

    pred_flat = _flat_numpy(pred)
    true_flat = _flat_numpy(true)
    accuracy, aligned, _, _ = hungarian_match(pred_flat, true_flat, n_classes)
    return accuracy, aligned


def compute_acc_ari_nmi(pred, true, n_classes=None):
    """
    Return (best_acc, best_pred, ari, nmi) for one prediction/truth pair.

    Accuracy and the remapped predictions come from the Hungarian alignment;
    ARI and NMI are computed on the raw predictions because neither depends
    on how labels are named.
    """
    alignment = compute_accuracy_spectral(pred, true, n_classes)
    agreement = compute_ari_nmi(pred, true)
    accuracy, aligned_pred = alignment
    ari_score, nmi_score = agreement
    return accuracy, aligned_pred, ari_score, nmi_score


def compute_accuracy_multiclass(pred_llh, labels, n_classes):
    """
    Batch accuracy with Hungarian alignment per-sample.

    Each sample's argmax predictions are aligned to its own ground truth
    independently, so different samples may use different label permutations.

    Args:
        pred_llh: (B, N, K) logits/scores (array-like, or any object exposing
                  `.detach()` such as a torch.Tensor)
        labels:   (B, N) ground-truth indices (same accepted types)
        n_classes: int, forwarded to hungarian_match
    Returns:
        acc_mean: float, mean of the per-sample aligned accuracies
        best_matched_preds: (B, N) np.ndarray of aligned predictions
    """
    # to numpy
    if hasattr(pred_llh, "detach"):
        pred_llh = pred_llh.detach().cpu().numpy()
    else:
        pred_llh = np.asarray(pred_llh)
    if hasattr(labels, "detach"):
        labels = labels.detach().cpu().numpy()
    else:
        labels = np.asarray(labels)

    # Only the batch size is needed; the per-sample length is never used.
    B = labels.shape[0]
    pred_labels = np.argmax(pred_llh, axis=2).astype(int)

    best_matched_preds = np.zeros_like(labels, dtype=int)
    acc_total = 0.0
    for i in range(B):
        acc_i, best_pred_i, _, _ = hungarian_match(pred_labels[i], labels[i], n_classes)
        best_matched_preds[i] = best_pred_i
        acc_total += acc_i

    return acc_total / B, best_matched_preds


def compute_nmi_multiclass(pred_llh, labels, average_method: str = "arithmetic"):
    """
    Batch NMI for multiclass clustering (no alignment).

    Args:
        pred_llh: (B, N, K) logits/scores (numpy array or torch.Tensor)
        labels:   (B, N) ground-truth indices (numpy array or torch.Tensor)
        average_method: forwarded unchanged to sklearn's
                        normalized_mutual_info_score (e.g. 'arithmetic' or
                        'geometric')

    Returns:
        nmi_mean: float         # mean NMI over the batch
        nmi_list: np.ndarray    # per-sample NMI, shape (B,)
    """
    # Convert both inputs to numpy (tensors are detached and moved to CPU).
    if hasattr(pred_llh, "detach"):
        pred_llh = pred_llh.detach().cpu().numpy()
    else:
        pred_llh = np.asarray(pred_llh)
    if hasattr(labels, "detach"):
        labels = labels.detach().cpu().numpy()
    else:
        labels = np.asarray(labels)

    B = pred_llh.shape[0]
    pred_labels = np.argmax(pred_llh, axis=2).astype(int)

    # Per-sample NMI, computed with the module-level sklearn import.
    nmi_list = np.zeros(B, dtype=float)
    for i in range(B):
        nmi_list[i] = normalized_mutual_info_score(
            labels[i], pred_labels[i], average_method=average_method
        )

    return float(nmi_list.mean()), nmi_list



def gnn_compute_acc_ari_nmi_multiclass(pred_llh, labels, n_classes):
    """
    Batch accuracy (Hungarian-aligned per sample) plus mean ARI and NMI.

    Args:
        pred_llh: (B, N, K) logits/scores (array-like, or any object exposing
                  `.detach()` such as a torch.Tensor)
        labels:   (B, N) gt labels (same accepted types)
        n_classes: int, forwarded to hungarian_match
    Returns:
        A 4-tuple (acc_mean, best_matched_preds, ari_mean, nmi_mean):
            acc_mean: float, mean of the per-sample aligned accuracies
            best_matched_preds: (B, N) np.ndarray of aligned predictions
            ari_mean: float, mean per-sample adjusted Rand index
            nmi_mean: float, mean per-sample normalized mutual information
        The per-sample ARI/NMI values are only used to form the means; they
        are not returned.
    """
    # to numpy
    if hasattr(pred_llh, "detach"):
        pred_llh = pred_llh.detach().cpu().numpy()
    else:
        pred_llh = np.asarray(pred_llh)
    if hasattr(labels, "detach"):
        labels = labels.detach().cpu().numpy()
    else:
        labels = np.asarray(labels)

    B = labels.shape[0]
    pred_labels = np.argmax(pred_llh, axis=2).astype(int)

    best_matched_preds = np.zeros_like(labels, dtype=int)
    acc_total = 0.0
    ari_list, nmi_list = [], []

    for i in range(B):
        # Hungarian alignment for accuracy
        acc_i, best_pred_i, _, _ = hungarian_match(pred_labels[i], labels[i], n_classes)
        best_matched_preds[i] = best_pred_i
        acc_total += acc_i

        # ARI/NMI do not depend on label naming, so the raw (unaligned)
        # predictions are scored directly.
        ari_list.append(float(adjusted_rand_score(labels[i], pred_labels[i])))
        nmi_list.append(float(normalized_mutual_info_score(labels[i], pred_labels[i])))

    acc_mean = acc_total / B
    ari_mean = float(np.mean(ari_list))
    nmi_mean = float(np.mean(nmi_list))
    return acc_mean, best_matched_preds, ari_mean, nmi_mean


def compute_loss_multiclass(pred_llh, labels, n_classes):
    """
    Cross-entropy against targets re-indexed by a per-sample Hungarian match.

    For every sample, the argmax of the logits is matched against the true
    labels with hungarian_match. The resulting true->pred mapping rewrites the
    true labels into the index space of the logits, and the module-level
    `criterion` is evaluated on those rewritten targets.

    Args:
        pred_llh: (B, N, K) logits
        labels:   (B, N) target indices (un-aligned, canonical)
        n_classes: int, number of classes (size of each lookup table)
    Returns:
        loss: the per-sample criterion values averaged over the batch
    """
    batch_size = pred_llh.shape[0]
    target_device = pred_llh.device

    # The alignment is a discrete lookup built from CPU copies; nothing in it
    # needs gradient tracking.
    with torch.no_grad():
        hard_preds = torch.argmax(pred_llh.detach(), dim=2).cpu().numpy()
        truth_np = labels.detach().cpu().numpy()

        lookup_tables = []
        for idx in range(batch_size):
            _, _, _, true_to_pred = hungarian_match(hard_preds[idx], truth_np[idx], n_classes)
            # Start from the identity so a class missing from the mapping
            # falls back to itself.
            table = np.arange(n_classes, dtype=int)
            for true_cls, pred_cls in true_to_pred.items():
                table[int(true_cls)] = int(pred_cls)
            lookup_tables.append(torch.from_numpy(table).to(target_device))

    # Accumulate the criterion over samples using the remapped targets.
    running_loss = 0.0
    for idx in range(batch_size):
        # labels[idx] has shape (N,); indexing the table rewrites every entry.
        remapped = lookup_tables[idx][labels[idx].long()].to(target_device)
        running_loss = running_loss + criterion(pred_llh[idx], remapped.long())

    return running_loss / batch_size
